{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import jieba\n",
    "import time\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import tensorflow as tf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer\n",
    "from sklearn.multiclass import OneVsRestClassifier\n",
    "from sklearn.naive_bayes import MultinomialNB,GaussianNB\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.svm import SVC\n",
    "from sklearn.neural_network import MLPClassifier\n",
    "from sklearn.naive_bayes import ComplementNB"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "root = './data/cleaning_data_95.csv'\n",
    "data = pd.read_csv(root)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 95分类 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "词汇表为：\n",
      " 20000\n",
      "Wall time: 24.8 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "corpus=data['item']\n",
    "\n",
    "from sklearn.feature_extraction.text import TfidfVectorizer\n",
    "tfidf_vect = TfidfVectorizer(max_features=20000,min_df=3,ngram_range=(2,4))\n",
    "vec = tfidf_vect.fit_transform(corpus)\n",
    "print('词汇表为：\\n', len(tfidf_vect.vocabulary_))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.preprocessing import MultiLabelBinarizer\n",
    "mlb = MultiLabelBinarizer()\n",
    "y = mlb.fit_transform(data.labels.apply(lambda x: x.split()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "95"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(y[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split\n",
    "x_train, x_test, y_train, y_test = train_test_split(vec, y, test_size=0.2, random_state=2020)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4516"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(y_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### LR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = OneVsRestClassifier(LogisticRegression())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\multiclass.py:75: UserWarning: Label 93 is present in all training examples.\n",
      "  str(classes[c]))\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "OneVsRestClassifier(estimator=LogisticRegression(C=1.0, class_weight=None,\n",
       "                                                 dual=False, fit_intercept=True,\n",
       "                                                 intercept_scaling=1,\n",
       "                                                 l1_ratio=None, max_iter=100,\n",
       "                                                 multi_class='auto',\n",
       "                                                 n_jobs=None, penalty='l2',\n",
       "                                                 random_state=None,\n",
       "                                                 solver='lbfgs', tol=0.0001,\n",
       "                                                 verbose=0, warm_start=False),\n",
       "                    n_jobs=None)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "clf.fit(x_train,y_train)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "pred = clf.predict(x_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.metrics import classification_report"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.00      0.00      0.00       184\n",
      "           1       0.00      0.00      0.00        94\n",
      "           2       0.00      0.00      0.00       143\n",
      "           3       0.00      0.00      0.00       184\n",
      "           4       0.00      0.00      0.00       220\n",
      "           5       0.00      0.00      0.00       140\n",
      "           6       0.00      0.00      0.00        57\n",
      "           7       0.00      0.00      0.00        57\n",
      "           8       0.00      0.00      0.00       135\n",
      "           9       0.00      0.00      0.00        64\n",
      "          10       0.00      0.00      0.00        78\n",
      "          11       0.99      0.14      0.25       883\n",
      "          12       0.00      0.00      0.00       108\n",
      "          13       0.00      0.00      0.00       140\n",
      "          14       0.00      0.00      0.00       140\n",
      "          15       0.00      0.00      0.00       140\n",
      "          16       0.98      0.16      0.27       277\n",
      "          17       0.00      0.00      0.00        76\n",
      "          18       0.00      0.00      0.00        76\n",
      "          19       1.00      0.01      0.01       290\n",
      "          20       0.00      0.00      0.00       220\n",
      "          21       0.00      0.00      0.00        83\n",
      "          22       1.00      0.04      0.08       282\n",
      "          23       0.00      0.00      0.00       166\n",
      "          24       0.00      0.00      0.00       502\n",
      "          25       1.00      0.01      0.01       175\n",
      "          26       0.87      0.05      0.10       532\n",
      "          27       0.00      0.00      0.00       180\n",
      "          28       1.00      0.04      0.08        71\n",
      "          29       0.00      0.00      0.00         9\n",
      "          30       0.00      0.00      0.00        84\n",
      "          31       0.00      0.00      0.00        57\n",
      "          32       0.00      0.00      0.00        57\n",
      "          33       0.50      0.00      0.01       334\n",
      "          34       0.00      0.00      0.00       124\n",
      "          35       1.00      0.04      0.07       978\n",
      "          36       0.00      0.00      0.00        97\n",
      "          37       1.00      0.05      0.10       187\n",
      "          38       0.92      0.15      0.25       164\n",
      "          39       0.92      0.15      0.25       164\n",
      "          40       0.00      0.00      0.00       177\n",
      "          41       0.84      0.09      0.16       182\n",
      "          42       0.96      0.14      0.25       188\n",
      "          43       0.00      0.00      0.00       184\n",
      "          44       0.00      0.00      0.00       126\n",
      "          45       1.00      0.02      0.03       694\n",
      "          46       0.00      0.00      0.00        64\n",
      "          47       0.00      0.00      0.00       102\n",
      "          48       0.98      0.15      0.25       371\n",
      "          49       0.00      0.00      0.00       110\n",
      "          50       0.00      0.00      0.00       220\n",
      "          51       0.00      0.00      0.00        66\n",
      "          52       0.00      0.00      0.00       220\n",
      "          53       0.00      0.00      0.00        97\n",
      "          54       0.00      0.00      0.00       220\n",
      "          55       0.00      0.00      0.00        99\n",
      "          56       0.00      0.00      0.00       190\n",
      "          57       0.36      0.02      0.04       211\n",
      "          58       1.00      0.01      0.01       140\n",
      "          59       0.00      0.00      0.00       183\n",
      "          60       0.00      0.00      0.00        60\n",
      "          61       0.00      0.00      0.00         2\n",
      "          62       0.72      1.00      0.83      2635\n",
      "          63       1.00      0.08      0.14       171\n",
      "          64       0.99      0.14      0.25       883\n",
      "          65       0.99      0.16      0.28       427\n",
      "          66       0.58      0.06      0.11       321\n",
      "          67       0.81      0.06      0.12       797\n",
      "          68       0.00      0.00      0.00       184\n",
      "          69       0.98      0.16      0.27       277\n",
      "          70       1.00      0.01      0.02        92\n",
      "          71       0.00      0.00      0.00        74\n",
      "          72       1.00      0.01      0.02        92\n",
      "          73       0.88      0.01      0.02       625\n",
      "          74       0.00      0.00      0.00        97\n",
      "          75       0.00      0.00      0.00       215\n",
      "          76       0.00      0.00      0.00       215\n",
      "          77       0.00      0.00      0.00        65\n",
      "          78       0.00      0.00      0.00        66\n",
      "          79       0.00      0.00      0.00        65\n",
      "          80       0.00      0.00      0.00         0\n",
      "          81       0.00      0.00      0.00        66\n",
      "          82       1.00      0.04      0.08        71\n",
      "          83       0.00      0.00      0.00        99\n",
      "          84       0.00      0.00      0.00        57\n",
      "          85       0.00      0.00      0.00       121\n",
      "          86       0.00      0.00      0.00       162\n",
      "          87       0.00      0.00      0.00        69\n",
      "          88       0.00      0.00      0.00       179\n",
      "          89       0.98      0.08      0.15       507\n",
      "          90       1.00      0.08      0.15       519\n",
      "          91       0.99      0.14      0.25       883\n",
      "          92       0.00      0.00      0.00       184\n",
      "          93       1.00      1.00      1.00      4516\n",
      "          94       0.00      0.00      0.00       220\n",
      "\n",
      "   micro avg       0.88      0.30      0.45     26812\n",
      "   macro avg       0.32      0.05      0.06     26812\n",
      "weighted avg       0.65      0.30      0.32     26812\n",
      " samples avg       0.88      0.32      0.45     26812\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1272: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n",
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1272: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "print(classification_report(y_test,pred))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 高斯Bayes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# nb = OneVsRestClassifier(GaussianNB())\n",
    "# nb.fit(x_train.toarray(),y_train)  \n",
    "# prd1 = nb.predict(x_test.toarray())\n",
    "# print(classification_report(y_test,prd1))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**超出内存**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\multiclass.py:75: UserWarning: Label 93 is present in all training examples.\n",
      "  str(classes[c]))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.00      0.00      0.00       184\n",
      "           1       0.01      0.01      0.01        94\n",
      "           2       0.01      0.01      0.01       143\n",
      "           3       0.00      0.00      0.00       184\n",
      "           4       0.01      0.00      0.01       220\n",
      "           5       0.05      0.03      0.04       140\n",
      "           6       0.00      0.00      0.00        57\n",
      "           7       0.00      0.00      0.00        57\n",
      "           8       0.03      0.01      0.02       135\n",
      "           9       0.02      0.03      0.03        64\n",
      "          10       0.00      0.00      0.00        78\n",
      "          11       0.96      0.40      0.57       883\n",
      "          12       0.08      0.06      0.07       108\n",
      "          13       0.05      0.03      0.04       140\n",
      "          14       0.05      0.03      0.04       140\n",
      "          15       0.05      0.03      0.04       140\n",
      "          16       0.60      0.27      0.37       277\n",
      "          17       0.00      0.00      0.00        76\n",
      "          18       0.00      0.00      0.00        76\n",
      "          19       0.14      0.03      0.05       290\n",
      "          20       0.01      0.00      0.01       220\n",
      "          21       0.02      0.02      0.02        83\n",
      "          22       0.36      0.13      0.19       282\n",
      "          23       0.08      0.04      0.05       166\n",
      "          24       0.15      0.02      0.03       502\n",
      "          25       0.21      0.12      0.15       175\n",
      "          26       0.05      0.01      0.01       532\n",
      "          27       0.00      0.00      0.00       180\n",
      "          28       0.14      0.18      0.16        71\n",
      "          29       0.00      0.00      0.00         9\n",
      "          30       0.00      0.00      0.00        84\n",
      "          31       0.00      0.00      0.00        57\n",
      "          32       0.00      0.00      0.00        57\n",
      "          33       0.08      0.02      0.03       334\n",
      "          34       0.00      0.00      0.00       124\n",
      "          35       0.93      0.21      0.34       978\n",
      "          36       0.00      0.00      0.00        97\n",
      "          37       0.29      0.16      0.21       187\n",
      "          38       0.41      0.29      0.34       164\n",
      "          39       0.41      0.29      0.34       164\n",
      "          40       0.19      0.10      0.13       177\n",
      "          41       0.28      0.14      0.19       182\n",
      "          42       0.46      0.29      0.35       188\n",
      "          43       0.00      0.00      0.00       184\n",
      "          44       0.00      0.00      0.00       126\n",
      "          45       0.67      0.09      0.16       694\n",
      "          46       0.01      0.02      0.01        64\n",
      "          47       0.05      0.04      0.04       102\n",
      "          48       0.70      0.26      0.38       371\n",
      "          49       0.01      0.01      0.01       110\n",
      "          50       0.01      0.00      0.01       220\n",
      "          51       0.00      0.00      0.00        66\n",
      "          52       0.01      0.00      0.01       220\n",
      "          53       0.00      0.00      0.00        97\n",
      "          54       0.01      0.00      0.01       220\n",
      "          55       0.03      0.02      0.02        99\n",
      "          56       0.00      0.00      0.00       190\n",
      "          57       0.26      0.13      0.17       211\n",
      "          58       0.05      0.04      0.04       140\n",
      "          59       0.08      0.03      0.05       183\n",
      "          60       0.02      0.03      0.03        60\n",
      "          61       0.00      0.00      0.00         2\n",
      "          62       0.90      1.00      0.95      2635\n",
      "          63       0.34      0.21      0.26       171\n",
      "          64       0.96      0.40      0.57       883\n",
      "          65       0.72      0.29      0.41       427\n",
      "          66       0.48      0.22      0.30       321\n",
      "          67       0.74      0.25      0.37       797\n",
      "          68       0.00      0.00      0.00       184\n",
      "          69       0.60      0.27      0.37       277\n",
      "          70       0.04      0.04      0.04        92\n",
      "          71       0.00      0.00      0.00        74\n",
      "          72       0.04      0.04      0.04        92\n",
      "          73       0.64      0.13      0.21       625\n",
      "          74       0.00      0.00      0.00        97\n",
      "          75       0.03      0.01      0.01       215\n",
      "          76       0.03      0.01      0.01       215\n",
      "          77       0.00      0.00      0.00        65\n",
      "          78       0.00      0.00      0.00        66\n",
      "          79       0.00      0.00      0.00        65\n",
      "          80       0.00      0.00      0.00         0\n",
      "          81       0.01      0.02      0.01        66\n",
      "          82       0.14      0.18      0.16        71\n",
      "          83       0.00      0.00      0.00        99\n",
      "          84       0.00      0.00      0.00        57\n",
      "          85       0.00      0.00      0.00       121\n",
      "          86       0.00      0.00      0.00       162\n",
      "          87       0.00      0.00      0.00        69\n",
      "          88       0.07      0.03      0.05       179\n",
      "          89       0.75      0.24      0.37       507\n",
      "          90       0.77      0.24      0.37       519\n",
      "          91       0.96      0.40      0.57       883\n",
      "          92       0.00      0.00      0.00       184\n",
      "          93       1.00      1.00      1.00      4516\n",
      "          94       0.01      0.00      0.01       220\n",
      "\n",
      "   micro avg       0.59      0.37      0.46     26812\n",
      "   macro avg       0.18      0.09      0.11     26812\n",
      "weighted avg       0.54      0.37      0.41     26812\n",
      " samples avg       0.94      0.38      0.49     26812\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "D:\\Program Files\\Anaconda3\\lib\\site-packages\\sklearn\\metrics\\_classification.py:1272: UndefinedMetricWarning: Recall and F-score are ill-defined and being set to 0.0 in labels with no true samples. Use `zero_division` parameter to control this behavior.\n",
      "  _warn_prf(average, modifier, msg_start, len(result))\n"
     ]
    }
   ],
   "source": [
    "cnb = OneVsRestClassifier(ComplementNB())\n",
    "cnb.fit(x_train,y_train)\n",
    "pred = cnb.predict(x_test)\n",
    "print(classification_report(y_test,pred))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
